pytorch - nn

Lecture 23

Dr. Colin Rundel

Odds & Ends

Torch models

Implementation details:

  • Models are implemented as a class inheriting from torch.nn.Module

  • Must implement constructor and forward() method

    • __init__() should call parent constructor via super()

      • Use torch.nn.Parameter() to indicate model parameters
    • forward() should implement the model - constants + parameters -> return predictions

Fitting procedure:

  • For each iteration of solver:

    • Get current predictions via a call to forward() or equivalent.

    • Calculate a (scalar) loss or equivalent

    • Call backward() method on loss

    • Use built-in optimizer (step() and then zero_grad() if necessary)

From last time

class Model(torch.nn.Module):
    """Linear model ``y_hat = X @ beta`` fit by gradient descent on a loss.

    Parameters
    ----------
    X : 2d tensor of predictors (n_obs x n_features).
    y : 1d tensor of responses.
    beta : optional starting coefficient vector; defaults to zeros.
    """

    def __init__(self, X, y, beta=None):
        super().__init__()
        self.X = X
        self.y = y
        if beta is None:
            beta = torch.zeros(X.shape[1])
        # torch.nn.Parameter() registers beta with the module and sets
        # requires_grad=True itself, so the original explicit
        # `beta.requires_grad = True` was redundant (and mutated a
        # caller-supplied tensor in place); it has been removed.
        self.beta = torch.nn.Parameter(beta)

    def forward(self, X):
        # Predictions for the current value of beta.
        return X @ self.beta

    def fit(self, opt, n=1000, loss_fn=torch.nn.MSELoss()):
        """Run ``n`` optimizer steps; return the per-iteration loss values.

        opt : a torch optimizer constructed over self.parameters().
        """
        losses = []
        for i in range(n):
            # squeeze() so predictions and targets broadcast as 1d vectors
            loss = loss_fn(
                self(self.X).squeeze(),
                self.y.squeeze()
            )
            loss.backward()
            opt.step()
            opt.zero_grad()
            losses.append(loss.item())

        return losses

What is self(self.X)?

This is (mostly) just shorthand for calling self.forward(self.X) to generate the output tensors from the current value(s) of the parameters.

This is done via the __call__() method in the torch.nn.Module class. __call__() allows python classes to be invoked like functions.


class greet:
    """Callable object that prefixes a fixed greeting to a given name."""

    def __init__(self, greeting):
        # Remember the greeting; it is reused on every call.
        self.greeting = greeting

    def __call__(self, name):
        # Instances behave like functions: greet("Hi")("Bob") -> "Hi Bob".
        return " ".join([self.greeting, name])
hello = greet("Hello")
hello("Jane")
'Hello Jane'
gm = greet("Good morning")
gm("Bob")
'Good morning Bob'

MNIST & Logistic models

MNIST handwritten digits - simplified

from sklearn.datasets import load_digits
digits = load_digits()
X = digits.data
X.shape
(1797, 64)
X[0:2]
array([[ 0.,  0.,  5., 13.,  9.,  1.,  0.,
         0.,  0.,  0., 13., 15., 10., 15.,
         5.,  0.,  0.,  3., 15.,  2.,  0.,
        11.,  8.,  0.,  0.,  4., 12.,  0.,
         0.,  8.,  8.,  0.,  0.,  5.,  8.,
         0.,  0.,  9.,  8.,  0.,  0.,  4.,
        11.,  0.,  1., 12.,  7.,  0.,  0.,
         2., 14.,  5., 10., 12.,  0.,  0.,
         0.,  0.,  6., 13., 10.,  0.,  0.,
         0.],
       [ 0.,  0.,  0., 12., 13.,  5.,  0.,
         0.,  0.,  0.,  0., 11., 16.,  9.,
         0.,  0.,  0.,  0.,  3., 15., 16.,
         6.,  0.,  0.,  0.,  7., 15., 16.,
        16.,  2.,  0.,  0.,  0.,  0.,  1.,
        16., 16.,  3.,  0.,  0.,  0.,  0.,
         1., 16., 16.,  6.,  0.,  0.,  0.,
         0.,  1., 16., 16.,  6.,  0.,  0.,
         0.,  0.,  0., 11., 16., 10.,  0.,
         0.]])
y = digits.target
y.shape
(1797,)
y[0:10]
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

Example digits

Test train split

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, shuffle=True, random_state=1234
)
X_train.shape
(1437, 64)
y_train.shape
(1437,)
X_test.shape
(360, 64)
y_test.shape
(360,)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
lr = LogisticRegression(
  penalty=None
).fit(
  X_train, y_train
)
accuracy_score(y_train, lr.predict(X_train))
1.0
accuracy_score(y_test, lr.predict(X_test))
0.9583333333333334

As Torch tensors

X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train)
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test)
X_train.shape
torch.Size([1437, 64])
y_train.shape
torch.Size([1437])
X_test.shape
torch.Size([360, 64])
y_test.shape
torch.Size([360])
X_train.dtype
torch.float32
y_train.dtype
torch.int64
X_test.dtype
torch.float32
y_test.dtype
torch.int64

PyTorch Model

class mnist_model(torch.nn.Module):
    """Multinomial logistic regression: logits = X @ beta + intercept.

    input_dim : number of predictor columns (64 for the 8x8 digits).
    output_dim : number of classes (10 digits).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # torch.nn.Parameter() already sets requires_grad=True, so the
        # explicit flag in torch.randn() was redundant and is removed.
        self.beta = torch.nn.Parameter(torch.randn(input_dim, output_dim))
        self.intercept = torch.nn.Parameter(torch.randn(output_dim))

    def forward(self, X):
        # Returns unnormalized class scores (logits); CrossEntropyLoss
        # applies log-softmax internally, so no activation is needed here.
        return (X @ self.beta + self.intercept).squeeze()

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000):
        """Fit by SGD with momentum for n steps; return per-step train losses.

        X_test / y_test are accepted for interface consistency with the
        later model classes but are not used by this version.
        """
        # Construct the loss module once, not on every iteration.
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses = []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()

            losses.append(loss.item())

        return losses

Cross entropy loss

model = mnist_model(64, 10)
l = model.fit(X_train, y_train, X_test, y_test)

Cross entropy loss

From the pytorch documentation:

\[ \ell(x, y)=L=\left\{l_1, \ldots, l_N\right\}^{\top}, \quad l_n=-w_{y_n} \log \frac{\exp \left(x_{n, y_n}\right)}{\sum_{c=1}^C \exp \left(x_{n, c}\right)} \]

\[ \ell(x, y)= \begin{cases}\sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot 1\left\{y_n \neq \text { ignore_index }\right\}} l_n, & \text { if reduction }=\text { 'mean' } \\ \sum_{n=1}^N l_n, & \text { if reduction }=\text { 'sum' }\end{cases} \]

Out-of-sample accuracy

model(X_test)
tensor([[-1.0983e+01, -2.0008e+01,
         -1.6966e+01, -1.1108e+01,
         -4.7596e+01, -3.2255e+01,
         -3.8194e+01,  3.8499e+01,
         -2.3000e+01, -3.7694e+01],
        [-6.8701e+01,  2.7801e+01,
         -9.0201e+01, -1.3878e+01,
         -1.9980e+01, -2.5793e+01,
         -7.5506e+01, -4.5115e+01,
          1.8130e+01,  5.1021e+01],
        [-2.9823e+01, -3.8137e+01,
         -1.9869e+01, -2.1973e+01,
         -4.1130e+01, -6.2912e+01,
         -9.2013e+01,  4.0552e+01,
         -3.6866e+01, -2.5443e+01],
        [ 1.0965e+01, -9.5212e+00,
          5.1235e+00,  1.9886e+00,
          3.0938e+01,  8.1608e+00,
          4.9731e+01, -6.1979e+01,
          9.9582e+00, -6.0213e+01],
        [ 5.7148e+01, -2.5446e+01,
         -1.2387e+01, -3.6341e+01,
          1.5746e+01, -2.2035e+01,
          4.8012e-01,  1.9123e+01,
          1.1575e+01, -6.5589e+00],
        [-1.7374e+01,  1.7241e+01,
          4.6331e+01,  1.9936e-01,
         -5.7807e+00, -4.4062e+01,
          4.8822e-01, -7.7437e-01,
         -5.3461e+01, -1.5989e+01],
        [ 3.6291e+00, -1.5089e+01,
         -3.3886e+01, -3.6683e+01,
          4.9672e+01, -2.4007e+01,
          4.4296e+01, -9.4572e+00,
         -4.4221e+01, -6.2859e+01],
        [-1.1917e+01, -7.5823e+00,
          2.6831e+01,  5.3106e+01,
         -6.4875e+01,  4.8146e+00,
         -2.8159e+01, -2.3453e+01,
          6.8259e+00,  1.2164e+01],
        [-1.0152e+01, -1.0208e+01,
         -2.0835e+01, -1.4411e+01,
         -2.4293e+01, -4.1586e+01,
          4.6814e+01, -3.9935e+01,
         -3.2164e+00, -4.5243e+01],
        [-3.9425e+01,  2.2337e+01,
          2.6488e+01,  4.1818e+01,
         -2.6865e+01, -4.0368e+01,
         -5.6509e+01, -8.2241e+00,
          8.7910e+00, -1.8999e+01],
        [-5.1291e+01, -4.4022e+01,
         -2.7472e+01, -6.8412e+00,
         -2.6886e+01, -6.6358e+01,
         -1.6226e+01,  2.1603e+01,
         -4.7202e+01, -8.1746e+01],
        [-6.3773e+01,  1.1413e+01,
         -9.8010e+00,  1.9536e+01,
         -5.4951e+01, -3.2417e+01,
         -3.7432e+01, -2.1155e+01,
          4.2713e+01, -2.2549e+01],
        [-4.4662e+01, -1.3609e+01,
         -1.9359e+01, -1.7326e+01,
         -2.7284e+01, -3.1602e+01,
         -1.8396e+01,  3.3744e+01,
         -3.1244e+01, -3.6754e+01],
        [ 2.3975e+01,  1.6375e+01,
         -8.9384e+00,  1.2496e+01,
          1.4170e+00,  1.9628e+01,
         -2.3724e+01,  8.9986e+00,
          2.4782e+01,  4.8668e+01],
        [-7.2887e+01,  3.2491e+00,
         -4.3919e+01, -2.0338e+01,
          4.2730e+01, -7.4927e+01,
         -4.6325e+00, -4.2972e+01,
         -5.4431e+00, -5.4876e+01],
        [-2.9159e+01, -1.6170e+01,
          3.2072e+01,  5.0037e+01,
         -6.6885e+01,  8.2702e+00,
         -2.6394e+01, -6.0928e+00,
          3.6999e+01,  1.3298e+01],
        [-3.3288e+01,  1.5003e+01,
          3.6782e-01,  2.8657e+00,
         -4.0117e+01, -7.9027e+01,
         -4.1195e+01, -1.1044e+00,
         -9.7067e+00, -3.8750e+01],
        [-3.7865e+01, -2.7485e+01,
         -2.3550e+01,  1.2197e+01,
         -6.9072e+01, -2.6240e+01,
         -6.2424e+01,  1.4435e+01,
         -3.1178e+01, -4.1359e+01],
        [-1.5808e+01,  7.7005e+00,
         -5.5255e+00, -2.4479e+00,
         -8.0713e+01,  1.2930e+01,
         -3.1769e+01, -1.4021e+01,
          5.5261e+01,  1.3289e+01],
        [-3.7390e+01, -6.7502e+00,
         -5.8905e+01, -4.2050e+01,
          8.1528e+01, -7.3395e+01,
          1.6926e+01, -5.1744e+01,
         -9.5693e-01, -4.5775e+01],
        [ 8.6150e+01, -3.8499e+01,
          3.8935e+01, -4.3044e+01,
         -1.7088e+01,  1.4571e+01,
          1.3798e+01,  1.6255e+01,
         -2.3764e+01,  3.6795e+01],
        [-2.7595e+01,  3.2364e+00,
          1.1729e+01,  7.5451e+01,
         -6.2852e+01,  1.2571e-01,
         -3.1686e+01, -5.1163e+01,
          1.9591e+01,  2.2289e+01],
        [ 9.8054e+00, -2.4634e+01,
          2.3899e+00,  1.5314e+01,
         -4.0233e+01,  4.6784e+01,
         -2.3932e+01, -1.6791e+01,
          3.9816e+01,  6.8989e+01],
        [-3.3894e+01,  5.6371e+01,
          2.9164e+01,  2.4246e+01,
         -3.0160e+01, -1.1803e+01,
          2.9049e+01, -1.3833e+01,
          9.7778e+00,  9.5518e+00],
        [-3.8119e+01,  1.0889e+01,
          2.4118e+01,  5.1269e+01,
          7.5642e+00,  1.8002e+01,
         -4.0984e+01,  8.6164e+00,
         -1.3371e+01,  2.5396e+01],
        [-1.9395e+01, -1.2434e+01,
         -1.8457e+01, -2.4513e+01,
         -1.9884e+01, -4.5851e+01,
          5.5021e+01, -6.6469e+01,
         -6.9826e+00, -3.3108e+01],
        [-3.5663e+00, -3.0038e+01,
         -5.7817e+00, -8.7770e+00,
         -2.4644e+01, -6.5059e+00,
          3.3370e+01, -6.7320e+01,
         -6.3216e+00, -5.9966e+01],
        [ 5.6996e+01, -1.0854e+01,
          2.0637e+01, -3.3219e+00,
         -2.5902e+01, -7.7458e+00,
          9.3995e+00, -1.6706e+01,
          1.2495e+01,  1.8916e+01],
        [-3.8560e+00, -4.1797e+01,
         -4.3053e+01, -3.3061e+01,
          5.0190e+00,  3.8283e+01,
         -1.1814e+01,  2.1737e+01,
         -2.3719e+01,  2.7442e+01],
        [ 4.3532e+00, -2.6687e+01,
         -2.7771e+01, -5.5524e+01,
          9.2524e+01, -2.9315e+01,
          2.8819e+01, -3.0699e+01,
         -3.7713e+01, -4.5623e+01],
        ...,
        [-6.6726e+01,  8.8321e+00,
         -5.2034e+00,  6.8852e+01,
         -5.8641e+01,  7.7089e+00,
         -4.1937e+01, -3.1379e+01,
          3.5323e+01, -7.6779e-01],
        [-3.7462e+01, -3.0843e+01,
         -5.3008e+01, -4.7459e+01,
          9.2180e+01, -4.3316e+01,
         -1.7296e+01, -1.9943e+01,
          1.2438e+01, -5.3616e+01],
        [-9.4958e+01,  3.5365e+01,
         -3.4748e+00, -1.4255e+00,
          1.5328e+01, -6.2635e+01,
         -4.4198e+01, -8.0639e+00,
         -9.8105e+00, -3.2063e+01],
        [-1.5232e+01, -6.5207e+00,
         -1.2599e+01, -5.2258e+01,
         -2.3461e+01, -2.8231e+01,
         -3.3888e+01, -7.6782e+00,
          1.4524e+01, -4.3439e+01],
        [-8.1029e+01,  2.3989e+01,
         -1.3485e+00, -1.1932e+00,
          3.5744e+00, -4.5838e+01,
         -3.3249e+01, -3.7776e+00,
         -3.7759e+00, -2.6305e+01],
        [-1.6771e+01, -3.1709e+01,
         -9.7816e+00,  6.2425e+00,
         -4.8710e+01, -1.0582e+01,
         -9.8310e+00, -3.1274e+01,
          4.8900e+01,  2.6955e+01],
        [-2.4689e+01, -1.2266e+01,
         -1.1837e+01,  2.1174e-01,
         -8.0655e+01,  9.0777e+01,
         -4.6780e+01, -6.9694e+00,
          8.1077e+00, -7.0857e-01],
        [ 7.9275e+01, -5.3766e+01,
          4.0466e+01, -2.7958e+01,
         -1.7032e+01,  9.4554e+00,
         -1.2416e+00,  1.7785e+01,
         -7.2788e+00,  2.8780e+01],
        [-6.9010e+00, -1.5549e+01,
         -5.4385e+01, -3.6759e+01,
         -6.3645e+00,  2.3008e+01,
         -5.9797e+01,  2.1162e+01,
          1.7089e+01,  6.3785e+01],
        [-1.5135e+01,  1.3867e+01,
          5.7583e+01,  3.4935e+01,
         -3.1474e+00, -1.7859e+00,
         -1.6855e+01, -3.6023e+01,
         -4.6566e+01, -6.7450e+00],
        [-2.4234e+01, -4.6458e+01,
         -5.4653e+01, -2.5673e+01,
          5.2738e-01, -7.3508e+01,
         -8.6025e+01,  3.5347e+01,
         -1.9576e+01, -3.7265e+01],
        [-6.1378e+01, -5.1832e+00,
          4.9104e+01,  1.5125e+01,
         -5.6675e+01, -4.2243e+01,
         -2.7425e+01, -1.7921e+01,
         -3.7454e+01, -1.0724e+01],
        [-7.7316e+00,  1.2216e+01,
          1.2966e+01,  4.0781e+01,
         -5.1465e+01, -9.7857e+00,
         -2.2215e+01, -1.6616e+01,
         -1.6453e+01,  2.2582e+01],
        [-1.4709e+01, -2.2589e+01,
         -5.0535e+00, -3.3308e+01,
         -9.1225e+01,  4.9168e+01,
         -3.9033e+01,  3.9662e+00,
         -8.4880e+00,  4.8270e+00],
        [-2.4275e+01, -1.3292e+01,
          5.0624e+01,  3.4073e+00,
         -3.9657e+01, -1.7126e+01,
         -3.1335e+01, -1.3191e+01,
         -3.7508e+01, -1.8477e+01],
        [ 6.5589e+00,  1.5932e+00,
         -1.2048e+01, -4.3538e+00,
         -2.4319e+01, -5.7922e+01,
          4.0782e+01, -6.2314e+01,
         -5.3869e+00, -1.8023e+01],
        [-5.2549e+01,  5.7667e+00,
         -2.8723e+00,  5.3425e+01,
         -6.1007e+01, -6.1018e-01,
         -5.4261e+01, -1.5245e+01,
          1.1939e+01,  5.6581e+00],
        [-7.0780e+01,  3.0500e-01,
         -8.9566e+00, -3.9406e+01,
          4.4562e+01, -2.3348e+01,
         -1.2745e+01, -5.2471e+00,
         -9.5060e+00,  1.3753e+00],
        [-8.6350e+01,  2.7579e+01,
         -3.8082e+00, -1.9529e+00,
         -2.5453e+00, -4.8824e+01,
         -3.3622e+01,  2.4087e+00,
         -6.8192e+00, -3.2603e+01],
        [-1.6557e+01, -2.2075e+01,
         -2.1497e+01,  5.8682e+00,
         -3.9492e+01,  6.2002e+01,
         -4.1819e+00, -2.0640e+01,
          4.7597e+00,  4.0438e+01],
        [ 8.5592e+01, -4.0591e+01,
          2.7864e+01, -1.4933e+01,
         -7.9904e+00,  9.1447e+00,
          1.2705e+01, -6.3926e+00,
          3.0333e+00,  1.9845e+01],
        [-2.0023e+01,  8.5516e-01,
         -1.7213e+01,  8.7641e-01,
         -9.3665e+01,  3.4271e+01,
         -5.4252e+01, -8.5454e-01,
          3.4802e+01,  2.3282e+01],
        [-3.2434e+01,  1.9979e+01,
         -4.2918e+01, -3.8029e+01,
          5.1043e+01, -3.9316e+01,
         -2.6696e+01,  1.0152e+01,
         -1.1994e+01,  5.2423e+00],
        [-5.3895e+00, -3.4634e+01,
         -1.7818e+01, -2.2728e+01,
         -2.5532e+01, -1.9375e+01,
          4.6944e+01, -5.4207e+01,
         -2.0725e+01, -6.4649e+01],
        [-2.0494e+01,  7.6246e-02,
         -1.0180e+01,  3.9928e+01,
         -3.6597e+01,  2.3926e+01,
         -1.7435e+01, -2.7441e+01,
          5.1775e+00,  1.1479e+01],
        [-3.0709e+01, -2.8010e+01,
          2.3290e+01,  2.6252e-01,
         -8.2072e+01, -1.9324e+01,
         -2.5542e+01, -4.1387e+01,
         -5.2302e+00, -2.9637e+01],
        [-2.3968e+01, -2.9830e+01,
         -1.3620e+01,  1.9619e+01,
         -6.0419e+01,  1.0207e+02,
         -3.6840e+01, -1.3030e+01,
          1.3754e+01,  4.0373e+01],
        [ 3.6165e+01,  6.8243e+00,
         -3.0811e+01, -9.5144e+00,
         -8.0594e+00,  7.2796e+00,
         -8.8233e+00, -6.4205e+00,
          3.0813e+01,  2.9250e+01],
        [-3.9446e+01, -1.5409e+00,
         -5.4306e+01, -1.7449e+01,
          1.0985e+00, -2.2979e+01,
         -3.6111e+01,  2.8276e+01,
         -2.4273e+01,  1.8006e+01],
        [-5.9601e+01,  5.1991e+00,
         -1.2906e+01,  5.0384e+01,
         -6.3010e+01,  5.3884e+00,
         -2.3827e+01, -3.4212e+01,
          2.1940e+01, -1.7732e+01]],
       grad_fn=<SqueezeBackward0>)
val, index = torch.max(model(X_test), dim=1)
index
tensor([7, 9, 7, 6, 0, 2, 4, 3, 6, 3, 7, 8, 7,
        9, 4, 3, 1, 7, 8, 4, 0, 3, 9, 1, 3, 6,
        6, 0, 5, 4, 1, 2, 1, 2, 3, 2, 7, 6, 1,
        8, 6, 4, 4, 0, 9, 1, 8, 5, 4, 4, 4, 1,
        7, 6, 8, 2, 9, 9, 7, 0, 8, 3, 1, 8, 8,
        1, 3, 9, 9, 3, 9, 6, 9, 5, 8, 1, 9, 2,
        1, 3, 8, 7, 3, 3, 1, 7, 7, 5, 8, 2, 6,
        1, 9, 1, 6, 4, 5, 2, 2, 4, 5, 6, 4, 6,
        5, 9, 2, 4, 1, 0, 7, 6, 1, 2, 9, 5, 2,
        5, 0, 3, 2, 7, 6, 4, 3, 2, 1, 1, 6, 7,
        6, 8, 1, 4, 7, 5, 0, 9, 1, 0, 5, 6, 7,
        6, 3, 8, 3, 2, 0, 4, 0, 1, 5, 4, 6, 1,
        1, 1, 6, 1, 7, 0, 0, 7, 9, 5, 4, 1, 3,
        8, 6, 4, 9, 1, 5, 7, 4, 7, 4, 0, 2, 2,
        1, 1, 4, 6, 3, 5, 5, 9, 4, 5, 5, 9, 3,
        9, 2, 1, 2, 0, 8, 2, 8, 6, 2, 4, 6, 8,
        3, 9, 1, 0, 8, 1, 8, 5, 6, 8, 9, 1, 8,
        0, 4, 9, 7, 0, 5, 5, 6, 1, 3, 0, 5, 8,
        2, 0, 9, 6, 6, 7, 8, 4, 1, 0, 5, 2, 5,
        1, 6, 4, 7, 1, 6, 6, 4, 4, 6, 3, 2, 3,
        2, 6, 5, 2, 9, 7, 7, 0, 1, 0, 4, 3, 1,
        2, 7, 9, 8, 5, 3, 5, 7, 0, 4, 8, 4, 9,
        4, 0, 7, 7, 3, 5, 3, 5, 2, 9, 7, 9, 5,
        2, 7, 4, 3, 9, 1, 7, 9, 8, 5, 0, 6, 0,
        8, 7, 0, 9, 5, 5, 9, 6, 1, 2, 3, 3, 6,
        3, 2, 9, 3, 0, 3, 4, 1, 8, 1, 8, 5, 0,
        9, 2, 7, 2, 3, 5, 2, 6, 3, 4, 1, 5, 0,
        8, 4, 6, 3, 2, 5, 0, 7, 3])
(index == y_test).sum()
tensor(319)
(index == y_test).sum() / len(y_test)
tensor(0.8861)

Calculating Accuracy

class mnist_model(torch.nn.Module):
    """Multinomial logistic regression that also tracks accuracy while fitting.

    input_dim : number of predictor columns.
    output_dim : number of classes.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # torch.nn.Parameter() already sets requires_grad=True, so the
        # explicit flag in torch.randn() was redundant and is removed.
        self.beta = torch.nn.Parameter(torch.randn(input_dim, output_dim))
        self.intercept = torch.nn.Parameter(torch.randn(output_dim))

    def forward(self, X):
        # Logits; CrossEntropyLoss applies log-softmax internally.
        return (X @ self.beta + self.intercept).squeeze()

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Fit with SGD + momentum; every acc_step iterations record train and
        test accuracy. Returns (losses, train_acc, test_acc)."""
        # Construct the loss module once, not on every iteration.
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()
            losses.append(loss.item())

            if (i + 1) % acc_step == 0:
                # Accuracy is evaluation only -- skip autograd bookkeeping.
                with torch.no_grad():
                    train_pred = torch.max(self(X_train), dim=1).indices
                    test_pred = torch.max(self(X_test), dim=1).indices

                    # .item() so accuracies are plain floats, consistent with
                    # the fnn/conv variants of this class.
                    train_acc.append((train_pred == y_train).sum().item() / len(y_train))
                    test_acc.append((test_pred == y_test).sum().item() / len(y_test))

        return (losses, train_acc, test_acc)

Performance

loss, train_acc, test_acc = mnist_model(
  64, 10
).fit(
  X_train, y_train, X_test, y_test, acc_step=10, n=3000
)

NN Layers

class mnist_nn_model(torch.nn.Module):
    """Logistic regression expressed as a single torch.nn.Linear layer.

    The Linear layer bundles weights (A) and bias (b): y = X A^T + b.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, X):
        # Logits; CrossEntropyLoss applies log-softmax internally.
        return self.linear(X)

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Fit with SGD + momentum; every acc_step iterations record train and
        test accuracy. Returns (losses, train_acc, test_acc)."""
        # Construct the loss module once, not on every iteration.
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()
            losses.append(loss.item())

            if (i + 1) % acc_step == 0:
                # Accuracy is evaluation only -- skip autograd bookkeeping.
                with torch.no_grad():
                    train_pred = torch.max(self(X_train), dim=1).indices
                    test_pred = torch.max(self(X_test), dim=1).indices

                    # .item() so accuracies are plain floats, consistent with
                    # the fnn/conv variants of this class.
                    train_acc.append((train_pred == y_train).sum().item() / len(y_train))
                    test_acc.append((test_pred == y_test).sum().item() / len(y_test))

        return (losses, train_acc, test_acc)

NN linear layer

Applies a linear transform to the incoming data (\(X\)): \[y = X A^T+b\]

X.shape
(1797, 64)
model = mnist_nn_model(64, 10)
model.parameters()
<generator object Module.parameters at 0x30ee31b60>
list(model.parameters())[0].shape  # A - weights (betas)
torch.Size([10, 64])
list(model.parameters())[1].shape  # b - bias
torch.Size([10])

Performance

loss, train_acc, test_acc = model.fit(X_train, y_train, X_test, y_test, n=1000)
train_acc[-5:]
[tensor(0.9937), tensor(0.9937), tensor(0.9937), tensor(0.9937), tensor(0.9937)]
test_acc[-5:]
[tensor(0.9611), tensor(0.9611), tensor(0.9611), tensor(0.9611), tensor(0.9611)]

Feedforward Neural Network

FNN Model

class mnist_fnn_model(torch.nn.Module):
    """One-hidden-layer feedforward network: Linear -> activation -> Linear.

    nl_step : the non-linear activation module placed between the layers
              (e.g. torch.nn.ReLU() or torch.nn.Tanh()).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, nl_step = torch.nn.ReLU(), seed=1234):
        super().__init__()
        # NOTE(review): `seed` is accepted but never used anywhere in this
        # class; honoring it (e.g. via torch.manual_seed) would change
        # existing behavior, so it is left as-is but flagged.
        self.l1 = torch.nn.Linear(input_dim, hidden_dim)
        self.nl = nl_step
        self.l2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, X):
        # hidden = nl(l1(X)); logits = l2(hidden)
        out = self.l1(X)
        out = self.nl(out)
        out = self.l2(out)
        return out

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Fit with SGD + momentum; every acc_step iterations record train and
        test accuracy. Returns (losses, train_acc, test_acc)."""
        # Construct the loss module once, not on every iteration.
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()

            losses.append(loss.item())

            if (i + 1) % acc_step == 0:
                # Accuracy is evaluation only -- skip autograd bookkeeping.
                with torch.no_grad():
                    train_pred = torch.max(self(X_train), dim=1).indices
                    test_pred = torch.max(self(X_test), dim=1).indices

                    train_acc.append((train_pred == y_train).sum().item() / len(y_train))
                    test_acc.append((test_pred == y_test).sum().item() / len(y_test))

        return (losses, train_acc, test_acc)

Non-linear activation functions

\[\text{Tanh}(x) = \frac{\exp(x)-\exp(-x)}{\exp(x) + \exp(-x)}\]

\[\text{ReLU}(x) = \max(0,x)\]

Model parameters

model = mnist_fnn_model(64,64,10)
len(list(model.parameters()))
4
for i, p in enumerate(model.parameters()):
  print("Param", i, p.shape)
Param 0 torch.Size([64, 64])
Param 1 torch.Size([64])
Param 2 torch.Size([10, 64])
Param 3 torch.Size([10])

Performance - ReLU

loss, train_acc, test_acc = mnist_fnn_model(64,64,10).fit(
  X_train, y_train, X_test, y_test, n=2000
)
train_acc[-5:]
[0.9979123173277662, 0.9979123173277662, 0.9979123173277662, 0.9979123173277662, 0.9979123173277662]
test_acc[-5:]
[0.975, 0.975, 0.975, 0.975, 0.975]

Performance - tanh

loss, train_acc, test_acc = mnist_fnn_model(64,64,10, nl_step=torch.nn.Tanh()).fit(
  X_train, y_train, X_test, y_test, n=2000
)
train_acc[-5:]
[0.9951287404314544, 0.9951287404314544, 0.9951287404314544, 0.9951287404314544, 0.9951287404314544]
test_acc[-5:]
[0.9722222222222222, 0.9722222222222222, 0.9722222222222222, 0.9722222222222222, 0.9722222222222222]

Adding another layer

class mnist_fnn2_model(torch.nn.Module):
    """Two-hidden-layer feedforward network; both hidden layers share the
    same activation module.

    nl_step : the non-linear activation module used after l1 and l2.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, nl_step = torch.nn.ReLU(), seed=1234):
        super().__init__()
        # NOTE(review): `seed` is accepted but never used; left as-is to
        # preserve behavior, but flagged.
        self.l1 = torch.nn.Linear(input_dim, hidden_dim)
        # Fix: the original assigned self.nl = nl_step twice (identical
        # statements); a single assignment suffices since the same module
        # is applied after both l1 and l2.
        self.nl = nl_step
        self.l2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.l3 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, X):
        # logits = l3(nl(l2(nl(l1(X)))))
        out = self.l1(X)
        out = self.nl(out)
        out = self.l2(out)
        out = self.nl(out)
        out = self.l3(out)
        return out

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Fit with SGD + momentum; every acc_step iterations record train and
        test accuracy. Returns (losses, train_acc, test_acc)."""
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()

            losses.append(loss.item())

            if (i + 1) % acc_step == 0:
                # Accuracy is evaluation only -- skip autograd bookkeeping.
                with torch.no_grad():
                    train_pred = torch.max(self(X_train), dim=1).indices
                    test_pred = torch.max(self(X_test), dim=1).indices

                    train_acc.append((train_pred == y_train).sum().item() / len(y_train))
                    test_acc.append((test_pred == y_test).sum().item() / len(y_test))

        return (losses, train_acc, test_acc)

Performance - relu

loss, train_acc, test_acc = mnist_fnn2_model(
  64,64,10, nl_step=torch.nn.ReLU()
).fit(
  X_train, y_train, X_test, y_test, n=1000
)
train_acc[-5:]
[0.9874739039665971, 0.9874739039665971, 0.9874739039665971, 0.9874739039665971, 0.9874739039665971]
test_acc[-5:]
[0.9638888888888889, 0.9638888888888889, 0.9638888888888889, 0.9638888888888889, 0.9638888888888889]

Performance - tanh

loss, train_acc, test_acc = mnist_fnn2_model(
  64,64,10, nl_step=torch.nn.Tanh()
).fit(
  X_train, y_train, X_test, y_test, n=1000
)
train_acc[-5:]
[0.9832985386221295, 0.9832985386221295, 0.9839944328462074, 0.9839944328462074, 0.9853862212943633]
test_acc[-5:]
[0.9611111111111111, 0.9638888888888889, 0.9638888888888889, 0.9666666666666667, 0.9666666666666667]

Convolutional NN

2d convolutions

nn.Conv2d()

cv = torch.nn.Conv2d(
  in_channels=1, out_channels=4, 
  kernel_size=3, 
  stride=1, padding=1
)
list(cv.parameters())[0] # kernel weights
Parameter containing:
tensor([[[[ 1.8830e-03, -1.4839e-02,
            8.2560e-03],
          [ 3.2457e-01,  3.1790e-01,
           -9.0510e-03],
          [ 2.2771e-01,  2.3472e-01,
           -2.7522e-01]]],

        [[[ 1.8235e-01,  1.6332e-05,
            3.3145e-01],
          [ 2.4308e-01,  5.7604e-03,
           -2.4588e-01],
          [-2.6864e-01,  2.9254e-01,
            2.4063e-01]]],

        [[[-9.5288e-02, -2.5502e-01,
           -2.0350e-01],
          [ 7.5498e-02, -2.9131e-02,
           -1.0224e-01],
          [ 5.0378e-02,  1.0951e-01,
           -2.4827e-01]]],

        [[[-1.9948e-01,  1.6291e-01,
            1.7735e-01],
          [ 2.6702e-01, -4.2565e-02,
           -2.4433e-02],
          [-1.7569e-01, -1.6232e-01,
            2.5236e-01]]]],
       requires_grad=True)
list(cv.parameters())[1] # biases
Parameter containing:
tensor([-0.2965, -0.1130,  0.1528,  0.1705],
       requires_grad=True)

Applying Conv2d()

X_train[[0]]
tensor([[ 0.,  0.,  0., 10., 11.,  0.,  0.,
          0.,  0.,  0.,  9., 16.,  6.,  0.,
          0.,  0.,  0.,  0., 15., 13.,  0.,
          0.,  0.,  0.,  0.,  0., 14., 10.,
          0.,  0.,  0.,  0.,  0.,  1., 15.,
         12.,  8.,  2.,  0.,  0.,  0.,  0.,
         12., 16., 16., 16., 10.,  1.,  0.,
          0.,  7., 16., 12., 12., 16.,  4.,
          0.,  0.,  0.,  9., 15., 12.,  5.,
          0.]])
X_train[[0]].shape
torch.Size([1, 64])
cv(X_train[[0]])
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 64]
X_train[[0]].view(1,8,8)
tensor([[[ 0.,  0.,  0., 10., 11.,  0.,  0.,
           0.],
         [ 0.,  0.,  9., 16.,  6.,  0.,  0.,
           0.],
         [ 0.,  0., 15., 13.,  0.,  0.,  0.,
           0.],
         [ 0.,  0., 14., 10.,  0.,  0.,  0.,
           0.],
         [ 0.,  1., 15., 12.,  8.,  2.,  0.,
           0.],
         [ 0.,  0., 12., 16., 16., 16., 10.,
           1.],
         [ 0.,  0.,  7., 16., 12., 12., 16.,
           4.],
         [ 0.,  0.,  0.,  9., 15., 12.,  5.,
           0.]]])
cv(X_train[[0]].view(1,8,8))
tensor([[[ -0.2965,  -2.7735,  -2.6780,
            6.9365,  11.4978,   4.6400,
           -0.2965,  -0.2965],
         [ -0.2965,  -4.5062,   2.4452,
           14.0662,   9.6199,   1.6716,
           -0.2965,  -0.2965],
         [ -0.2965,  -4.2110,   4.8867,
           14.0689,   6.1411,  -0.2852,
           -0.2965,  -0.2965],
         [ -0.5717,  -4.1929,   4.3942,
           11.2924,   7.0335,   1.9946,
            0.1589,  -0.2965],
         [ -0.3056,  -3.3014,   2.9759,
           10.2769,   9.1377,   7.5826,
            6.0680,   2.2154],
         [ -0.2882,  -2.2226,   0.4914,
           10.5030,  12.9160,  11.0236,
           13.4575,   7.8494],
         [ -0.2965,  -0.2608,  -0.7389,
            4.8547,  10.7951,  11.9998,
           12.4447,   7.3108],
         [ -0.2965,  -0.2387,  -0.3497,
            2.3037,   7.2356,   8.3182,
            5.0060,   1.2971]],

        [[ -0.1130,   2.0526,   3.9110,
            0.9464,  -0.1619,   0.9490,
           -0.1130,  -0.1130],
         [ -0.1130,   1.2834,   6.8354,
            4.1111,   2.1422,   3.3513,
           -0.1130,  -0.1130],
         [ -0.1130,   2.5506,   8.5820,
            6.4026,   3.2783,   0.9811,
           -0.1130,  -0.1130],
         [  0.1276,   5.3183,   8.8248,
            7.4890,   4.2862,  -1.6771,
           -0.6503,  -0.1130],
         [ -0.3589,   3.7323,   7.9411,
            9.4952,   8.4141,   4.6317,
           -0.7592,  -2.5069],
         [  0.2184,   3.5926,   6.0798,
           10.0366,   4.8852,   7.0052,
            6.3721,  -0.8046],
         [ -0.1130,   2.1432,   3.4622,
           12.4640,  13.9737,   5.8553,
            3.4008,   4.2796],
         [ -0.1130,   2.2071,   2.9774,
            1.5047,   6.1057,   9.8645,
            6.3470,   4.0200]],

        [[  0.1528,  -2.0816,  -3.8562,
           -0.5471,   2.0505,   1.2856,
            0.1528,   0.1528],
         [  0.1528,  -4.4913,  -5.3649,
           -2.8566,  -1.9172,  -0.4423,
            0.1528,   0.1528],
         [  0.1528,  -6.6879,  -8.1139,
           -3.4520,  -1.4167,  -0.4189,
            0.1528,   0.1528],
         [ -0.0954,  -7.9454,  -9.0344,
           -3.7425,   0.6532,   0.7749,
            0.2536,   0.1528],
         [  0.0506,  -7.2380,  -9.6989,
           -5.3820,  -1.7456,   0.7741,
            1.9567,   0.7661],
         [ -0.0507,  -6.1194, -11.4009,
           -8.0351,  -5.1908,  -3.4536,
            2.1403,   2.1228],
         [  0.1528,  -3.0048, -10.2376,
          -12.2298,  -9.9597,  -7.7381,
           -2.9425,   0.2883],
         [  0.1528,  -1.2716,  -5.8085,
           -8.8323,  -7.8584,  -7.0352,
           -5.1247,  -2.0144]],

        [[  0.1705,   2.4417,   2.5030,
           -3.1881,  -1.4124,   2.0536,
            0.1705,   0.1705],
         [  0.1705,   3.7360,   2.0158,
            0.5805,   1.7008,  -0.4216,
            0.1705,   0.1705],
         [  0.1705,   4.9332,   3.7692,
            1.4150,  -0.3293,  -1.0264,
            0.1705,   0.1705],
         [  0.4229,   6.1117,   4.4972,
            0.0445,  -2.6546,  -1.5596,
           -0.1809,   0.1705],
         [  0.1461,   5.2726,   5.6499,
            1.6383,  -0.3798,  -0.6630,
           -3.4773,  -1.7487],
         [  0.3478,   4.4670,   6.5426,
            1.8857,   0.9046,   2.2291,
           -0.1022,  -0.6621],
         [  0.1705,   2.1276,   6.5453,
            6.4403,   4.9036,   0.3401,
           -1.7091,   1.5623],
         [  0.1705,   1.4119,   3.9285,
            2.7593,   2.5335,   5.9416,
            4.0841,  -1.0344]]],
       grad_fn=<SqueezeBackward1>)

Pooling

x = torch.tensor(
  [[[0,0,0,0],
    [0,1,2,0],
    [0,3,4,0],
    [0,0,0,0]]],
  dtype=torch.float
)
x.shape
torch.Size([1, 4, 4])
torch.nn.MaxPool2d(
  kernel_size=2, stride=1
)(x)
tensor([[[1., 2., 2.],
         [3., 4., 4.],
         [3., 4., 4.]]])
torch.nn.MaxPool2d(
  kernel_size=3, stride=1, padding=1
)(x)
tensor([[[1., 2., 2., 2.],
         [3., 4., 4., 4.],
         [3., 4., 4., 4.],
         [3., 4., 4., 4.]]])
torch.nn.AvgPool2d(
  kernel_size=2
)(x)
tensor([[[0.2500, 0.5000],
         [0.7500, 1.0000]]])
torch.nn.AvgPool2d(
  kernel_size=2, padding=1
)(x)
tensor([[[0.0000, 0.0000, 0.0000],
         [0.0000, 2.5000, 0.0000],
         [0.0000, 0.0000, 0.0000]]])

Convolutional model

class mnist_conv_model(torch.nn.Module):
    """Small CNN for flattened 8x8 digit images: conv -> ReLU -> max pool -> linear."""

    def __init__(self):
        super().__init__()
        # 1 input channel -> 8 feature maps; padding=1 keeps the 8x8 spatial size
        self.cnn  = torch.nn.Conv2d(
          in_channels=1, out_channels=8,
          kernel_size=3, stride=1, padding=1
        )
        self.relu = torch.nn.ReLU()
        # 2x2 pooling halves each spatial dim: 8x8 -> 4x4
        self.pool = torch.nn.MaxPool2d(kernel_size=2)
        self.lin  = torch.nn.Linear(8 * 4 * 4, 10)

    def forward(self, X):
        """Map rows of 64 pixel values to 10 unnormalized class scores (logits)."""
        out = self.cnn(X.view(-1, 1, 8, 8))
        out = self.relu(out)
        out = self.pool(out)
        out = self.lin(out.view(-1, 8 * 4 * 4))
        return out

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Train for n full-batch SGD steps; return (losses, train_acc, test_acc).

        Accuracies are recorded every `acc_step` iterations as plain floats.
        """
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()

            losses.append(loss.item())

            if (i + 1) % acc_step == 0:
                # Evaluation only -- skip autograd bookkeeping for these passes.
                with torch.no_grad():
                    train_pred = self(X_train).argmax(dim=1)
                    test_pred  = self(X_test).argmax(dim=1)

                train_acc.append((train_pred == y_train).sum().item() / len(y_train))
                test_acc.append((test_pred == y_test).sum().item() / len(y_test))

        return (losses, train_acc, test_acc)

Performance

loss, train_acc, test_acc = mnist_conv_model().fit(
  X_train, y_train, X_test, y_test, n=1000
)
train_acc[-5:]
[0.9937369519832986, 0.9937369519832986, 0.9937369519832986, 0.9937369519832986, 0.9937369519832986]
test_acc[-5:]
[0.9694444444444444, 0.9694444444444444, 0.9694444444444444, 0.9694444444444444, 0.9694444444444444]

Organizing models

class mnist_conv_model2(torch.nn.Module):
    """CNN for flattened 8x8 digits, expressed as a single Sequential pipeline."""

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
          torch.nn.Unflatten(1, (1,8,8)),    # (batch, 64) -> (batch, 1, 8, 8)
          torch.nn.Conv2d(
            in_channels=1, out_channels=8,
            kernel_size=3, stride=1, padding=1
          ),
          torch.nn.ReLU(),
          torch.nn.MaxPool2d(kernel_size=2), # 8x8 -> 4x4
          torch.nn.Flatten(),                # (batch, 8, 4, 4) -> (batch, 128)
          torch.nn.Linear(8 * 4 * 4, 10)
        )

    def forward(self, X):
        """Return 10 class logits per input row."""
        return self.model(X)

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Train for n full-batch SGD steps; return (losses, train_acc, test_acc).

        The loss module is constructed once (not per iteration), evaluation
        runs under no_grad, and accuracies are stored as plain floats for
        consistency with mnist_conv_model.
        """
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()

            losses.append(loss.item())

            if (i + 1) % acc_step == 0:
                with torch.no_grad():
                    train_pred = self(X_train).argmax(dim=1)
                    test_pred  = self(X_test).argmax(dim=1)

                train_acc.append((train_pred == y_train).sum().item() / len(y_train))
                test_acc.append((test_pred == y_test).sum().item() / len(y_test))

        return (losses, train_acc, test_acc)

A bit more on non-linear
activation layers

Non-linear functions

# Load the example data set and convert it to tensors:
# X is an (n, 1) float32 column matrix, y a length-n float32 vector.
df = pd.read_csv("data/gp.csv")
X = torch.tensor(df["x"], dtype=torch.float32).reshape(-1,1)  # predictor as a column
y = torch.tensor(df["y"], dtype=torch.float32)                # response

Linear regression

class lin_reg(torch.nn.Module):
    """Plain linear regression fit by full-batch gradient descent."""

    def __init__(self, X):
        super().__init__()
        self.n = X.shape[0]  # number of observations
        self.p = X.shape[1]  # number of predictors
        # Map p predictors to a single response. (The previous Linear(p, p)
        # only coincided with a regression model when p == 1; an explicit
        # output size of 1 matches the other models in this file.)
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, 1)
        )

    def forward(self, X):
        """Return predictions of shape (n, 1)."""
        return self.model(X)

    def fit(self, X, y, n=1000):
        """Run n SGD steps minimizing MSE; return the per-step losses."""
        loss_fn = torch.nn.MSELoss()  # construct once, not every iteration
        opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        losses = []
        for i in range(n):
            loss = loss_fn(self(X).squeeze(), y)
            loss.backward()
            opt.step()
            opt.zero_grad()
            losses.append(loss.item())

        return losses

Model results

m1 = lin_reg(X)
loss = m1.fit(X,y, n=2000)

Training loss:

Predictions

Double linear regression

class dbl_lin_reg(torch.nn.Module):
    """Two stacked linear layers with no activation between them.

    Note: the composition of two linear maps is still linear, so this has
    the same expressive power as lin_reg -- it exists to illustrate that.
    """

    def __init__(self, X, hidden_dim=10):
        super().__init__()
        self.n = X.shape[0]  # number of observations
        self.p = X.shape[1]  # number of predictors
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, hidden_dim),
          torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, X):
        """Return predictions of shape (n, 1)."""
        return self.model(X)

    def fit(self, X, y, n=1000):
        """Run n SGD steps minimizing MSE; return the per-step losses."""
        loss_fn = torch.nn.MSELoss()  # construct once, not every iteration
        opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        losses = []
        for i in range(n):
            loss = loss_fn(self(X).squeeze(), y)
            loss.backward()
            opt.step()
            opt.zero_grad()
            losses.append(loss.item())

        return losses

Model results

m2 = dbl_lin_reg(X, hidden_dim=10)
loss = m2.fit(X,y, n=2000)

Training loss:

Predictions

Non-linear regression w/ ReLU

class lin_reg_relu(torch.nn.Module):
    """One ReLU-activated hidden layer: Linear -> ReLU -> Linear."""

    def __init__(self, X, hidden_dim=100):
        super().__init__()
        self.n = X.shape[0]
        self.p = X.shape[1]
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, hidden_dim),
          torch.nn.ReLU(),
          torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, X):
        """Return predictions for X."""
        return self.model(X)

    def fit(self, X, y, n=1000):
        """Minimize MSE with n steps of SGD (momentum 0.9); return loss history."""
        mse = torch.nn.MSELoss()
        opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        history = []
        for _ in range(n):
            step_loss = mse(self(X).squeeze(), y)
            step_loss.backward()
            opt.step()
            opt.zero_grad()
            history.append(step_loss.item())
        return history

Model results

Hidden dimensions

Non-linear regression w/ Tanh

class lin_reg_tanh(torch.nn.Module):
    """One Tanh-activated hidden layer: Linear -> Tanh -> Linear."""

    def __init__(self, X, hidden_dim=10):
        super().__init__()
        self.n = X.shape[0]
        self.p = X.shape[1]
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, hidden_dim),
          torch.nn.Tanh(),
          torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, X):
        """Return predictions for X."""
        return self.model(X)

    def fit(self, X, y, n=1000):
        """Minimize MSE with n steps of SGD (momentum 0.9); return loss history."""
        mse = torch.nn.MSELoss()
        opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        history = []
        for _ in range(n):
            step_loss = mse(self(X).squeeze(), y)
            step_loss.backward()
            opt.step()
            opt.zero_grad()
            history.append(step_loss.item())
        return history

Tanh & hidden dimension

Three layers

class three_layers(torch.nn.Module):
    """Two ReLU-activated hidden layers followed by a single linear output."""

    def __init__(self, X, hidden_dim=100):
        super().__init__()
        self.n = X.shape[0]
        self.p = X.shape[1]
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, hidden_dim),
          torch.nn.ReLU(),
          torch.nn.Linear(hidden_dim, hidden_dim),
          torch.nn.ReLU(),
          torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, X):
        """Return predictions for X."""
        return self.model(X)

    def fit(self, X, y, n=1000):
        """Minimize MSE with n steps of SGD (momentum 0.9); return loss history."""
        mse = torch.nn.MSELoss()
        opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        history = []
        for _ in range(n):
            step_loss = mse(self(X).squeeze(), y)
            step_loss.backward()
            opt.step()
            opt.zero_grad()
            history.append(step_loss.item())
        return history

Model results

Five layers

class five_layers(torch.nn.Module):
    """Feed-forward net: input -> n_hidden ReLU-activated hidden layers -> 1 output.

    The default n_hidden=4 reproduces the original hand-written architecture
    (5 Linear layers total); the depth is now a parameter instead of
    copy-pasted layer pairs.
    """

    def __init__(self, X, hidden_dim=100, n_hidden=4):
        super().__init__()
        self.n = X.shape[0]  # number of observations
        self.p = X.shape[1]  # number of predictors
        # First hidden layer maps p -> hidden_dim; remaining hidden layers
        # map hidden_dim -> hidden_dim; final layer maps to the 1-d response.
        layers = [torch.nn.Linear(self.p, hidden_dim), torch.nn.ReLU()]
        for _ in range(n_hidden - 1):
            layers += [torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.ReLU()]
        layers.append(torch.nn.Linear(hidden_dim, 1))
        self.model = torch.nn.Sequential(*layers)

    def forward(self, X):
        """Return predictions of shape (n, 1)."""
        return self.model(X)

    def fit(self, X, y, n=1000):
        """Run n SGD steps minimizing MSE; return the per-step losses."""
        loss_fn = torch.nn.MSELoss()  # construct once, not every iteration
        opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        losses = []
        for i in range(n):
            loss = loss_fn(self(X).squeeze(), y)
            loss.backward()
            opt.step()
            opt.zero_grad()
            losses.append(loss.item())

        return losses

Model results